BatchNorm

对输入数组按通道执行批归一化(Batch Normalization)计算。 该算子使用给定的均值与方差对输入进行标准化,并通过 epsilon 保证数值稳定性。

\[dst_{i,j} = \frac{src_{i,j} - mean_j}{\sqrt{variance_j + \epsilon}}\]

其中:

  • \(i\) 表示第 unit 个样本

  • \(j\) 表示通道索引

  • \(mean_j\)\(variance_j\) 为第 \(j\) 个通道的统计量

对于 int8 类型输入,内部以浮点方式计算,最终结果按实现规则取整并输出为 int8

输入:
  • input - 输入数据地址,形状为 [unit, channel]

  • mean - 均值数组地址,长度为 channel

  • variance - 方差数组地址,长度为 channel

  • unit - 样本数(或展开后的空间维度)。

  • channel - 通道数。

  • epsilon - 数值稳定因子。

  • core_mask - 核掩码(仅适用于共享存储版本)。

输出:
  • output - 批归一化后的输出数据地址。

支持平台:

FT78NE MT7004

备注

  • FT78NE 支持 int8fp32 类型

  • MT7004 支持 fp16fp32 类型

  • 当前实现不包含 scale 与 bias,仅执行标准化操作

共享存储版本:

void i8_batchnorm_s(int8_t *input, int8_t *output, float *mean, float *variance, int unit, int channel, float epsilon, int core_mask)
void fp_batchnorm_s(float *input, float *output, float *mean, float *variance, int unit, int channel, float epsilon, int core_mask)
void hp_batchnorm_s(half *input, half *output, float *mean, float *variance, int unit, int channel, float epsilon, int core_mask)

C调用示例:

 1// FT78NE 示例
 2#include <stdio.h>
 3#include <batchnorm.h>
 4
 5int main(int argc, char* argv[]) {
 6    int8_t *input  = (int8_t *)0xA0000000;   // input 在 DDR 空间
 7    int8_t *output = (int8_t *)0xC0000000;
 8    float *mean    = (float *)0xA1000000;
 9    float *var     = (float *)0xA2000000;
10    int unit = 128;
11    int channel = 64;
12    float epsilon = 1e-5f;
13    int core_mask = 0xff;
14
15    i8_batchnorm_s(input, output, mean, var, unit, channel, epsilon, core_mask);
16    return 0;
17}

私有存储版本:

void i8_batchnorm_p(int8_t *input, int8_t *output, float *mean, float *variance, int unit, int channel, float epsilon)
void fp_batchnorm_p(float *input, float *output, float *mean, float *variance, int unit, int channel, float epsilon)
void hp_batchnorm_p(half *input, half *output, float *mean, float *variance, int unit, int channel, float epsilon)

C调用示例:

 1// FT78NE 示例
 2#include <stdio.h>
 3#include <batchnorm.h>
 4
 5int main(int argc, char* argv[]) {
 6    int8_t *input  = (int8_t *)0x10810000;   // input 在 L2 空间
 7    int8_t *output = (int8_t *)0x10820000;
 8    float *mean    = (float *)0x10830000;
 9    float *var     = (float *)0x10840000;
10    int unit = 128;
11    int channel = 64;
12    float epsilon = 1e-5f;
13
14    i8_batchnorm_p(input, output, mean, var, unit, channel, epsilon);
15    return 0;
16}